
Feat/gdn decode pooled#2521

Merged
yzh119 merged 13 commits into flashinfer-ai:main from xutizhou:feat/gdn-decode-pooled
Mar 9, 2026

Conversation

@xutizhou
Contributor

@xutizhou xutizhou commented Feb 8, 2026

📌 Description

This PR adds pool-indexed (indirect) state access to the GDN decode kernel, enabling zero-copy integration with SGLang's state pool architecture.

Background: SGLang's State Pool Architecture

In SGLang, when serving linear attention models (like Qwen3-Next using Gated Delta Rule), we maintain a state pool to store recurrent states for all active requests:

ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]

where pool_size = max_num_reqs (maximum concurrent requests).

Each active request has a req_pool_idx that maps it to a slot in this pool. The mapping is not contiguous: requests come and go, so indices can be scattered (e.g., a batch of 4 requests might have pool indices [3, 7, 12, 25]).
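As a rough illustration of this layout (a NumPy sketch with made-up sizes and request names; SGLang's actual pool allocator is more involved):

```python
import numpy as np

# Illustrative sizes only; real dimensions come from the model/server config.
num_layers, pool_size, num_heads, head_dim = 2, 32, 4, 8

# ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]
ssm_states = np.zeros(
    (num_layers, pool_size, num_heads, head_dim, head_dim), dtype=np.float32
)

# Each active request owns one scattered slot in the pool.
req_pool_idx = {"req_a": 3, "req_b": 7, "req_c": 12, "req_d": 25}

# A decode batch of these 4 requests maps batch position -> pool slot.
state_indices = np.array(list(req_pool_idx.values()), dtype=np.int32)
```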

Motivation

The current GDN decode kernel expects state with shape [B, H, K, V] where B equals batch size and there's a 1:1 mapping (batch index i → state index i). To use it with SGLang's pool, we would need to:

  1. Gather states from pool indices before kernel call
  2. Run kernel on contiguous [B, H, K, V] state
  3. Scatter updated states back to pool indices

This adds 2 extra memory copy operations per decode step.
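The three steps boil down to a gather/compute/scatter pattern like this (a NumPy sketch; `decode_kernel_contiguous` is a hypothetical stand-in with placeholder math, not the FlashInfer API):

```python
import numpy as np

def decode_kernel_contiguous(q, k, v, beta, state):
    """Stand-in for the existing kernel: expects a contiguous [B, H, K, V]
    state with a 1:1 batch-to-state mapping. Placeholder update rule."""
    return q.copy(), state + 1.0

B, H, K, V = 4, 4, 8, 8
pool = np.zeros((32, H, K, V), dtype=np.float32)   # [pool_size, H, K, V]
indices = np.array([3, 7, 12, 25])                 # scattered pool slots
q = k = v = np.ones((B, H, K), dtype=np.float32)
beta = np.ones((B, H), dtype=np.float32)

local_state = pool[indices]                        # 1. gather (extra copy)
out, new_state = decode_kernel_contiguous(q, k, v, beta, local_state)
pool[indices] = new_state                          # 3. scatter (extra copy)
```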

Changes

This PR adds a state_indices parameter for zero-copy pool access:

def gated_delta_rule_decode_pretranspose(
    q, k, v, beta,
    state,           # Can be [pool_size, H, K, V] instead of [B, H, K, V]
    state_indices,   # NEW: int32 tensor [B] mapping batch_idx -> pool_idx
    ...
)

When state_indices is provided:

  • Kernel uses indirect addressing: state[state_indices[batch_idx]] instead of state[batch_idx]
  • Negative indices (padding slots for CUDA graph) skip computation and write zeros to output
  • Eliminates gather/scatter overhead + host-side torch.where for padding (~37μs/call)
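A minimal Python reference of these semantics (the real kernel is CUDA; the arithmetic below is a placeholder, and only the indexing and padding behavior mirrors the description):

```python
import numpy as np

def pooled_decode_reference(q, state, state_indices):
    """Reference semantics for pool-indexed decode: a negative index marks a
    padding slot (skip compute, zero the output, leave state untouched);
    otherwise the pool slot is read and updated in place."""
    out = np.empty_like(q)
    for b in range(q.shape[0]):
        idx = int(state_indices[b])
        if idx < 0:                       # padding slot (e.g. CUDA graph filler)
            out[b] = 0.0
            continue
        out[b] = q[b] * state[idx].sum()  # placeholder "read" of the slot
        state[idx] += 1.0                 # placeholder in-place pool update
    return out

state = np.zeros((8, 2, 2), dtype=np.float32)   # [pool_size, K, V]
q = np.ones((3, 2, 2), dtype=np.float32)
out = pooled_decode_reference(q, state, np.array([5, -1, 2], dtype=np.int32))
```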

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

This PR is required for integrating FlashInfer's K-last GDN kernels into SGLang. The pool indexing feature allows SGLang to directly use its state pool without gather/scatter overhead.


Summary by CodeRabbit

  • New Features

    • Optional pooled (indirect) state access via a new state_indices parameter enabling zero-copy pooled state handling; negative indices yield zeroed outputs.
  • Improvements

    • Kernel launch paths, grid sizing, and compiled-kernel caching differentiate pooled vs. non-pooled modes; APIs and docstrings updated to propagate pooling flags while maintaining compatibility.
  • Tests

    • New test suite validating pooled decode correctness, padding/negative-index behavior, state updates, and pooled vs. non-pooled equivalence.

…kernel negative index handling

Add use_pool_indexing constexpr to both small-batch and big-batch
pretranspose decode kernels, enabling zero-copy state access directly
from the pool via h0_indices, eliminating gather/scatter overhead.

Also handle negative pool indices (padding slots) inside the kernel:
blocks with negative indices skip computation and write zeros to output,
removing the need for host-side torch.where remap (~37us/call savings).

Combined effect: K-last decode is 4-5.6% faster than V-last at BS>=4.
…ingle function

Consolidate gated_delta_rule_decode_pretranspose_pooled into
gated_delta_rule_decode_pretranspose by adding an optional state_indices
parameter. When state_indices is provided, the kernel uses pool-indexed
(zero-copy) mode; otherwise it uses direct 1:1 batch-to-state mapping.

This eliminates ~175 lines of duplicated Python wrapper code while the
underlying CUDA kernels remain unchanged. The compiled kernel cache key
now includes pool_size and use_pool_indexing to ensure correct cache
separation between the two modes.
When using pool indexing (state_indices), a non-contiguous state tensor
could silently produce incorrect results because the kernel assumes
contiguous memory layout for pointer arithmetic. Add an explicit
assertion to catch this early.
@coderabbitai
Contributor

coderabbitai bot commented Feb 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds pooled (indirect) state access to gated-delta-rule decode in flashinfer/gdn_decode.py via state_indices / use_pool_indexing; threads pool-aware parameters through kernel signatures and compiled-kernel cache keys, adjusts grid sizing and h0/state preparation and padding, and adds tests validating pooled decode semantics.

Changes

  • GDN decode kernels & API (flashinfer/gdn_decode.py): Introduce state_indices: Optional[torch.Tensor] and use_pool_indexing: cutlass.Constexpr[bool] through public APIs and kernel entrypoints; adapt state addressing (pool-indexed vs direct), zero-padding semantics for negative indices, grid batch sizing (grid_batch = B * HV), and h0_source preparation when pooling is enabled.
  • Kernel launch, compilation & caching (flashinfer/gdn_decode.py, helper functions): Add _get_compiled_decode_kernel_nontranspose, extend the _get_compiled_decode_kernel cache key to include pool_size and use_pool_indexing, maintain separate caches per pooling mode; update run_gdn_decode_kernel_*_pretranspose wrappers to accept and forward use_pool_indexing (and an optional stream).
  • Nontranspose & MTP paths (flashinfer/gdn_decode.py, nontranspose variants): Update gdn_decode_kernel_*_nontranspose paths to accept use_pool_indexing and state_indices; change state copy-back logic to in-place pool updates when pooling is used, preserving the existing reshaping path for non-pooled mode.
  • Pooled decode tests (tests/gdn/test_decode_pooled.py): New test suite validating pooled decode behavior: negative-index padding, the sglang sentinel pattern, pooled vs non-pooled equivalence, and all-padding cases; compares CUDA kernel outputs/state updates against a Python reference and gates by CUDA SM capability.
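The pooled vs non-pooled equivalence case can be illustrated with a toy decode function (hypothetical placeholder math; the actual tests call the FlashInfer kernels):

```python
import numpy as np

def decode(q, state, state_indices=None):
    """Toy decode with optional pool indexing; placeholder arithmetic."""
    out = np.zeros_like(q)
    for b in range(q.shape[0]):
        idx = b if state_indices is None else int(state_indices[b])
        if idx < 0:
            continue                      # padding: output stays zero
        out[b] = q[b] + state[idx].sum()
        state[idx] += 1.0
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 2, 2)).astype(np.float32)

state_direct = np.ones((4, 2, 2), dtype=np.float32)
state_pooled = state_direct.copy()
identity = np.arange(4, dtype=np.int32)

out_direct = decode(q, state_direct)             # non-pooled path
out_pooled = decode(q, state_pooled, identity)   # pooled, identity mapping
assert np.allclose(out_direct, out_pooled)
assert np.allclose(state_direct, state_pooled)
```

With an identity mapping the two paths must touch the same slots in the same order, which is exactly the equivalence the test suite checks against the CUDA kernels.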

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant API as "gated_delta_rule_decode_pretranspose"
    participant Cache as "KernelCache/_get_compiled_decode_kernel"
    participant CUDA as "CUDA Kernel"
    participant Pool as "State Pool / h0_source"

    Client->>API: call gated_delta_rule_decode_pretranspose(state, state_indices?, ...)
    API->>Cache: request compiled kernel (pool_size, use_pool_indexing, ...)
    Cache-->>API: compiled kernel handle
    API->>CUDA: launch kernel with grid shaped by grid_batch (B*HV) and use_pool_indexing
    CUDA->>Pool: read state via state_indices (indirect) or direct mapping
    CUDA-->>API: write outputs and updated state slots (respecting negative-index padding)
    API-->>Client: return outputs and updated state buffer

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

v0.6.2, model: qwen3-next

Suggested reviewers

  • cyx-6
  • bkryu
  • nvmbreughe
  • kahyunnam
  • jimmyzho
  • yzh119

Poem

🐰 I hopped through kernels with indices bright,

Pooled slots pointed, outputs just right.
Negative pad? I left them alone,
States stayed tidy in their stone,
A happy rabbit, debugging by moonlight.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 inconclusive)

  • Title check: ❓ Inconclusive. The title is vague and uses non-descriptive phrasing 'Feat/gdn decode pooled' that lacks specifics about the actual feature implementation. Resolution: Consider a more specific title like 'Add pool-indexed state access to GDN decode kernel' to clearly convey the main change.

✅ Passed checks (4 passed)

  • Description check: ✅ Passed. The PR description is comprehensive and well-structured, covering background, motivation, changes, and including all required checklist sections with appropriate completion status.
  • Linked Issues check: ✅ Passed. The PR explicitly links to related SGLang integration work (sgl-project/sglang#18361) and mentions a follow-up port to the f16 state kernel (PR #2634).
  • Out of Scope Changes check: ✅ Passed. All changes are directly related to adding pool-indexed state access to GDN decode, with no unrelated modifications present.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 92.86%, which is sufficient; the required threshold is 80.00%.



@gemini-code-assist
Contributor

Summary of Changes

Hello @xutizhou, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization to the Gated Delta Rule (GDN) decode kernels by implementing a pooled indexing mechanism. This enhancement allows the kernels to directly access and update a shared state pool using indirect indices, thereby enabling a zero-copy mode that bypasses intermediate data transfers. The changes involve modifying kernel interfaces, adapting internal state management, and updating the kernel compilation and launch logic to support this more efficient state handling strategy.

Highlights

  • Pooled Indexing Feature: Introduced a new use_pool_indexing parameter and state_indices tensor to enable a zero-copy pooled mode for Gated Delta Rule (GDN) decode kernels. This allows the kernels to read and write state directly from a shared state pool using indirect indexing, eliminating the need for explicit gather/scatter operations.
  • Padding Handling: When use_pool_indexing is active, negative values in state_indices are treated as padding slots. For these slots, the kernel skips computation and writes zeros to the output, ensuring correct behavior for inactive batch elements.
  • Kernel Signature and Logic Updates: The gdn_decode_kernel_small_batch_pretranspose and gdn_decode_kernel_big_batch_pretranspose functions now accept the use_pool_indexing boolean. Their internal logic has been modified to compute state_idx based on h0_indices (when pooled indexing is enabled) and to conditionally execute the main computation loop based on pool_idx >= 0.
  • Wrapper Function Enhancements: The gated_delta_rule_decode_pretranspose function now accepts an optional state_indices tensor. It includes updated input validation for state shape and state_indices, adjusts the h0_source reshaping for pooled state, and modifies the kernel launch grid size calculation to use the actual batch size (B * HV) instead of the potentially larger pool size.
  • Kernel Caching: The _get_compiled_decode_kernel cache key has been extended to include pool_size and use_pool_indexing. This ensures that different configurations (pooled vs. non-pooled, and different pool sizes) result in separately compiled CUDA code, optimizing performance for each specific use case.


Changelog
  • flashinfer/gdn_decode.py
    • Added use_pool_indexing parameter to gdn_decode_kernel_small_batch_pretranspose and gdn_decode_kernel_big_batch_pretranspose.
    • Implemented conditional state indexing logic within kernels based on use_pool_indexing and h0_indices.
    • Added logic to write zeros to output for padding slots (negative pool_idx) in kernels.
    • Modified run_gdn_decode_kernel_small_batch_pretranspose and run_gdn_decode_kernel_big_batch_pretranspose to pass use_pool_indexing and adjust grid size for pooled mode.
    • Updated _get_compiled_decode_kernel cache key to include pool_size and use_pool_indexing.
    • Added state_indices parameter to gated_delta_rule_decode_pretranspose.
    • Updated docstrings for gated_delta_rule_decode_pretranspose to describe pooled mode and state_indices.
    • Adjusted state shape validation and h0_source reshaping in gated_delta_rule_decode_pretranspose for pooled state.
    • Modified the condition for copying state back to handle in-place updates in pooled mode.
    • Removed unused print statements and a calculation.
    • Relocated the _get_compiled_decode_kernel_nontranspose function definition.
Activity
  • The pull request was opened by xutizhou with the title 'Feat/gdn decode pooled', indicating a new feature for GDN decode with pooling capabilities.
  • The changes introduce a new use_pool_indexing parameter and state_indices argument across several functions, suggesting a significant architectural change to how state is managed.
  • The modifications include updates to kernel logic, function signatures, and caching mechanisms to support the new pooled indexing feature.
  • The PR description uses a standard template, implying that the author has followed the project's contribution guidelines, but does not contain additional specific comments or context beyond the template sections.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces pooled decoding for the Gated Delta Rule, a significant performance optimization for inference that avoids state reallocation. The changes are well-implemented, adding a use_pool_indexing flag and state_indices to map batch items to a state pool. The kernel logic, launcher functions, and public API have been updated accordingly, including necessary input validation and caching mechanism adjustments. I've identified one minor opportunity for code simplification to reduce duplication. Overall, this is a solid and valuable contribution.

    o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx]
else:
    # Padding slot: write zeros to output
    start_v_tiles = batch_inner * num_v_tiles_per_block
Contributor


medium

The variable start_v_tiles is calculated in both branches of the if pool_idx >= 0 condition (here and at line 248). To avoid code duplication and improve clarity, you can calculate it once before the if block (e.g., at line 221).

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 1042-1052: Add a bounds check for state_indices when
use_pool_indexing is true: after verifying contiguity, shape and dtype, assert
that all values are >= 0 and < pool_size (e.g., check state_indices.min() >= 0
and state_indices.max() < pool_size or use torch.any to detect OOB) to prevent
GPU OOB accesses; gate this check behind a debug flag (or a cheap runtime
condition like torch.is_grad_enabled() or an explicit debug parameter) so it
only runs in debug builds.
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)

237-240: Non-English comment in kernel code.

# V 方向分 tiles should be translated to English (e.g., # Tile along V dimension) for codebase consistency and accessibility.

Same issue at line 526.

Comment on lines +1042 to +1052
if use_pool_indexing:
    assert state.is_contiguous(), (
        "state must be contiguous when using pool indexing (state_indices); "
        "a non-contiguous tensor may silently produce incorrect results"
    )
    assert state_indices.shape == (B,), (
        f"Expected state_indices shape [{B}], got {state_indices.shape}"
    )
    assert state_indices.dtype == torch.int32, (
        f"state_indices must be int32, got {state_indices.dtype}"
    )
Contributor


⚠️ Potential issue | 🟡 Minor

Missing bounds validation: state_indices values are not checked against pool_size.

The assertion block validates shape and dtype of state_indices, but doesn't verify that non-negative values are < pool_size. An out-of-range index would silently cause an OOB global memory access in the kernel, potentially corrupting memory or causing a GPU fault.

Consider adding a debug-mode bounds check:

🛡️ Proposed fix
         assert state_indices.dtype == torch.int32, (
             f"state_indices must be int32, got {state_indices.dtype}"
         )
+        # Validate index bounds (non-negative indices must be < pool_size)
+        valid_mask = state_indices >= 0
+        if valid_mask.any():
+            max_idx = state_indices[valid_mask].max().item()
+            assert max_idx < pool_size, (
+                f"state_indices contains index {max_idx} >= pool_size={pool_size}"
+            )

This adds a small overhead, so you may want to gate it behind a debug flag or torch.is_grad_enabled() check.

🤖 Prompt for AI Agents
In `@flashinfer/gdn_decode.py` around lines 1042 - 1052, Add a bounds check for
state_indices when use_pool_indexing is true: after verifying contiguity, shape
and dtype, assert that all values are >= 0 and < pool_size (e.g., check
state_indices.min() >= 0 and state_indices.max() < pool_size or use torch.any to
detect OOB) to prevent GPU OOB accesses; gate this check behind a debug flag (or
a cheap runtime condition like torch.is_grad_enabled() or an explicit debug
parameter) so it only runs in debug builds.

@yzh119
Collaborator

yzh119 commented Feb 9, 2026

@xutizhou can you explain what this PR is about? e.g. adding descriptions about what pool is in GDN.

@xutizhou
Contributor Author

@xutizhou can you explain what this PR is about? e.g. adding descriptions about what pool is in GDN.

updated in the description.

@yzh119
Collaborator

yzh119 commented Feb 17, 2026

Hi @xutizhou, can you also port your work to the f16 state kernel that has already been merged in #2498?

@xutizhou
Contributor Author

Hi @xutizhou, can you also port your work to the f16 state kernel that has already been merged in #2498?

In this PR, or open a new PR?

xutizhou added a commit to xutizhou/flashinfer that referenced this pull request Feb 25, 2026
…lashinfer-ai#2521

Revert gdn_decode.py to base — the state_indices parameter and pool
validation in gated_delta_rule_decode_pretranspose belong to PR flashinfer-ai#2521.
This PR now only contains BF16 CuTe DSL kernel pool indexing changes
in gdn_decode_bf16_state.py.

AI-assisted (Claude)
Add comprehensive tests for gated_delta_rule_decode_pretranspose with
pool indexing (state_indices parameter):

- Test 1: Pooled decode with negative indices (~20% padding)
- Test 2: sglang forward_decode calling pattern (unique indices + PAD_SLOT_ID)
- Test 3: Pooled vs non-pooled equivalence with identity mapping
- Test 4: All-padding batch (output zeros, pool state unchanged)

All tests verify output and state against per-sample reference
implementation. AI-assisted.
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/gdn/test_decode_pooled.py (1)

40-41: Generalize the SM gate to avoid skipping future architectures.

Line 40 uses a fixed allowlist ([9, 10, 11, 12]), which will start skipping newer SM majors even when they should be valid. A lower-bound check is safer.

Suggested patch
-    if cc[0] not in [9, 10, 11, 12]:
-        pytest.skip(f"GDN decode requires SM90+ or SM100+, but got SM{cc[0]}{cc[1]}")
+    if cc[0] < 9:
+        pytest.skip(f"GDN decode requires SM90+, but got SM{cc[0]}{cc[1]}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_pooled.py` around lines 40 - 41, The test currently
hardcodes an allowlist of SM majors using cc[0] not in [9, 10, 11, 12], causing
future SM majors to be skipped; change this to a lower-bound check (e.g., if
cc[0] < 9) so any SM major >=9 is accepted, and update the pytest.skip message
to reflect the minimum required SM (use cc[0] to report actual SM and a message
like "requires SM9+"). Keep the check and message near the same spot where cc
and pytest.skip are used in tests/gdn/test_decode_pooled.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/gdn/test_decode_pooled.py`:
- Around line 229-230: The test simulates a sentinel-slot layout but samples
valid slot indices from [0, pool_size-1], allowing 0 to be treated as a real
request; update the sampling to exclude the sentinel by drawing cache indices
from the range 1..pool_size inclusive (instead of 0..pool_size-1) so slot 0
remains the sentinel. Change the sampling logic that builds cache_indices (and
any related uses around the block referencing pool_size and PAD_SLOT_ID) to use
the corrected range (e.g., start at 1 and go to pool_size) in both places
flagged (around the code that constructs cache_indices and the subsequent
samples).

---

Nitpick comments:
In `@tests/gdn/test_decode_pooled.py`:
- Around line 40-41: The test currently hardcodes an allowlist of SM majors
using cc[0] not in [9, 10, 11, 12], causing future SM majors to be skipped;
change this to a lower-bound check (e.g., if cc[0] < 9) so any SM major >=9 is
accepted, and update the pytest.skip message to reflect the minimum required SM
(use cc[0] to report actual SM and a message like "requires SM9+"). Keep the
check and message near the same spot where cc and pytest.skip are used in
tests/gdn/test_decode_pooled.py.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2799807 and a92052a.

📒 Files selected for processing (1)
  • tests/gdn/test_decode_pooled.py

Comment on lines +229 to +230
# - Full pool passed as state (pool_size+1 slots, slot 0 is sentinel)
# - cache_indices from scheduler, with PAD_SLOT_ID = -1 for padding
Contributor


⚠️ Potential issue | 🟠 Major

Sentinel-slot simulation is inconsistent with the stated SGLang layout.

Line 229 says slot 0 is sentinel, but Line 258 samples valid indices from [0, pool_size-1], so 0 can be used as a real request slot. This weakens the “exactly like SGLang” guarantee and can hide off-by-one mapping bugs.

Suggested patch
-    cache_indices_int64 = torch.randperm(pool_size, device=device)[:num_valid].to(
+    cache_indices_int64 = (torch.randperm(pool_size, device=device)[:num_valid] + 1).to(
         torch.int64
     )

Also applies to: 257-260

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_pooled.py` around lines 229 - 230, The test simulates a
sentinel-slot layout but samples valid slot indices from [0, pool_size-1],
allowing 0 to be treated as a real request; update the sampling to exclude the
sentinel by drawing cache indices from the range 1..pool_size inclusive (instead
of 0..pool_size-1) so slot 0 remains the sentinel. Change the sampling logic
that builds cache_indices (and any related uses around the block referencing
pool_size and PAD_SLOT_ID) to use the corrected range (e.g., start at 1 and go
to pool_size) in both places flagged (around the code that constructs
cache_indices and the subsequent samples).

…ode-pooled

# Conflicts:
#	flashinfer/gdn_decode.py
@xutizhou
Contributor Author

Hi @xutizhou, can you also port your work to the f16 state kernel that has already been merged in #2498?

Ported in a new PR: #2634

Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gdn_decode.py (1)

1067-1103: ⚠️ Potential issue | 🔴 Critical

Critical: bf16 fast-path ignores pooled indexing semantics.

When state_indices is provided and bf16 fast-path conditions match, the code still enters _gated_delta_rule_gdn_decode_klast_bf16_state (Line 1067-1073), but that call has no index mapping and therefore bypasses pooled indirection.

🔧 Proposed fix
     use_gdn_decode_klast_bf16_state = (
         _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE
+        and not use_pool_indexing
         and state.dtype == torch.bfloat16
         and T in (1, 2, 3, 4)
         and K == 128
         and V == 128
     )
+    if use_pool_indexing and state.dtype == torch.bfloat16:
+        raise NotImplementedError(
+            "state_indices with bfloat16 state is not supported in this decode path yet."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 1067 - 1103, The bf16 fast-path
(_gated_delta_rule_gdn_decode_klast_bf16_state) currently ignores pooled
indexing and thus bypasses state indirection when state_indices is set; fix by
detecting state_indices before taking the bf16 fast-path and either (a) disable
the bf16 fast-path (i.e., set use_gdn_decode_klast_bf16_state to False) so the
standard path that respects pooled indexing runs, or (b) materialize a gathered
initial_state_source = state.index_select(0, state_indices) (or equivalent
gather) and pass that gathered tensor as initial_state_source into
_gated_delta_rule_gdn_decode_klast_bf16_state so pooled indexing is honored;
ensure you reference state_indices, state, initial_state_source, and
_gated_delta_rule_gdn_decode_klast_bf16_state when making the change and
preserve the existing output handling logic.
♻️ Duplicate comments (1)
flashinfer/gdn_decode.py (1)

1053-1064: ⚠️ Potential issue | 🔴 Critical

Critical: validate state_indices upper bound before kernel launch.

state_indices shape/dtype are checked, but non-negative values are not constrained to < pool_size. This can trigger OOB GPU memory access in pooled mode (Line 1062 onward).

🛡️ Proposed fix
     if use_pool_indexing:
         assert state.is_contiguous(), (
             "state must be contiguous when using pool indexing (state_indices); "
             "a non-contiguous tensor may silently produce incorrect results"
         )
         assert state_indices.shape == (B,), (
             f"Expected state_indices shape [{B}], got {state_indices.shape}"
         )
         assert state_indices.dtype == torch.int32, (
             f"state_indices must be int32, got {state_indices.dtype}"
         )
+        valid_mask = state_indices >= 0
+        if torch.any(valid_mask):
+            max_idx = int(state_indices[valid_mask].max().item())
+            assert max_idx < pool_size, (
+                f"state_indices contains index {max_idx} >= pool_size={pool_size}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 1053 - 1064, The code validates
state_indices shape/dtype but misses verifying bounds, which can cause OOB GPU
access when using pooled mode; in the pooled-path (where use_pool_indexing is
True and before the kernel launch that uses state_indices) add a validation that
all values in state_indices are >= 0 and < pool_size (and keep the existing
dtype/int checks), e.g. check torch.all(state_indices >= 0) and
torch.all(state_indices < pool_size) (or equivalent CPU-side/min/max check) and
raise/assert with a clear message referencing state_indices and pool_size so
invalid indices are caught before the kernel is launched.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/gdn_decode.py`:
- Around line 1067-1103: The bf16 fast-path
(_gated_delta_rule_gdn_decode_klast_bf16_state) currently ignores pooled
indexing and thus bypasses state indirection when state_indices is set; fix by
detecting state_indices before taking the bf16 fast-path and either (a) disable
the bf16 fast-path (i.e., set use_gdn_decode_klast_bf16_state to False) so the
standard path that respects pooled indexing runs, or (b) materialize a gathered
initial_state_source = state.index_select(0, state_indices) (or equivalent
gather) and pass that gathered tensor as initial_state_source into
_gated_delta_rule_gdn_decode_klast_bf16_state so pooled indexing is honored;
ensure you reference state_indices, state, initial_state_source, and
_gated_delta_rule_gdn_decode_klast_bf16_state when making the change and
preserve the existing output handling logic.
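Option (b) from the prompt above could be sketched roughly as follows. This is an illustration only: `resolve_initial_state` and its return convention are hypothetical helpers, not the actual FlashInfer code.

```python
import torch

def resolve_initial_state(state, state_indices, use_bf16_fast_path):
    """Pick the state tensor the bf16 fast path should consume.

    Hypothetical helper sketching option (b): when pooled indexing is
    active, materialize a gathered copy so a fast path that ignores
    indirection still reads the right pool slots.
    """
    if state_indices is None or not use_bf16_fast_path:
        return state, use_bf16_fast_path
    # Clamp negative padding indices to slot 0; their outputs are
    # zeroed separately, so the gathered values are never observed.
    safe_idx = state_indices.clamp_min(0).to(torch.long)
    return state.index_select(0, safe_idx), True
```

The gather costs one extra copy, so option (a), falling back to the standard path, may still be preferable when pooled indexing is common.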

---

Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 1053-1064: The code validates state_indices shape/dtype but misses
verifying bounds, which can cause OOB GPU access when using pooled mode; in the
pooled-path (where use_pool_indexing is True and before the kernel launch that
uses state_indices) add a validation that all values in state_indices are >= 0
and < pool_size (and keep the existing dtype/int checks), e.g. check
torch.all(state_indices >= 0) and torch.all(state_indices < pool_size) (or
equivalent CPU-side/min/max check) and raise/assert with a clear message
referencing state_indices and pool_size so invalid indices are caught before the
kernel is launched.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a92052a and 85b261b.

📒 Files selected for processing (1)
  • flashinfer/gdn_decode.py


# Partition for load
thr_copy_load = tiled_copy_load.get_slice(tidx)
# Split into tiles along the V direction
Please fix

Contributor Author
done.

@yzh119
Collaborator

yzh119 commented Mar 4, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !371 has been created, and the CI pipeline #45304647 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 added the run-ci label Mar 4, 2026
…isted)

Replace from_dlpack(h0_source) with make_fake_compact_tensor using
cute.sym_int() for the pool_batch dimension, so a single compiled
kernel handles any pool_size at runtime. stride_order=(2,1,0) ensures
row-major layout with concrete strides for cp.async alignment.

Benchmarks show zero performance regression vs compile-time shape:
  from_dlpack: 0.0306ms median (bs=32, pool=128)
  sym_int:     0.0307ms median (bs=32, pool=128)
@xutizhou xutizhou force-pushed the feat/gdn-decode-pooled branch from 5f15fa8 to 308ad6b Compare March 4, 2026 08:55
@kaixih kaixih mentioned this pull request Mar 4, 2026
40 tasks
@flashinfer-bot
Collaborator

[FAILED] Pipeline #45304647: 6/20 passed

@xutizhou
Contributor Author

xutizhou commented Mar 5, 2026

/bot run

@flashinfer-bot
Collaborator

@xutizhou is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

xutizhou added 2 commits March 5, 2026 17:10
…om PR flashinfer-ai#2619

Resolve merge conflicts with upstream main which added pool+indices support
via the bf16 fast path (PR flashinfer-ai#2619). Key changes:
- Adopt upstream API naming: state_indices -> initial_state_indices,
  state pool passed via initial_state param
- Update test_decode_pooled.py to use new API with bf16 state
- Skip negative-index tests (bf16 kernel does not support them yet)
- Legacy f32 CuTe DSL path preserved for non-pool usage

AI-assisted merge resolution.
Remove 'assert not use_pool' from f32 path — the sym_int approach
already handles arbitrary pool_size at runtime with zero overhead.
Tests 1/2/4 use f32 state with negative indices (padding support).
Test 3 uses bf16 state (routed to bf16 fast path).

All 23 pooled decode tests pass. AI-assisted.
@yzh119
Collaborator

yzh119 commented Mar 5, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !371 has been updated with latest changes, and the CI pipeline #45437134 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #45437134: 10/20 passed

@xutizhou xutizhou requested a review from kaixih as a code owner March 7, 2026 04:09
@xutizhou xutizhou force-pushed the feat/gdn-decode-pooled branch from 29a2a7f to 01f091d Compare March 7, 2026 04:38
Merge test_decode_pooled.py into test_decode_delta_rule.py with:
- state_dtype parametrize (bf16 + f32) for pool test
- negative indices and all-padding tests (f32 state only)
- per-sample Python reference to avoid JIT cache contamination
- float32 dt_bias matching SGLang production usage
- pytestmark skip preserved to match upstream main CI
@xutizhou xutizhou force-pushed the feat/gdn-decode-pooled branch from 01f091d to e5df67c Compare March 7, 2026 04:56
@yzh119 yzh119 merged commit bcdf8d8 into flashinfer-ai:main Mar 9, 2026
29 of 30 checks passed
brandonmmusic-max pushed a commit to brandonmmusic-max/flashinfer that referenced this pull request Mar 9, 2026
## 📌 Description

This PR adds pool-indexed (indirect) state access to the GDN decode
kernel, enabling zero-copy integration with SGLang's state pool
architecture.

### Background: SGLang's State Pool Architecture

In SGLang, when serving linear attention models (like Qwen3-Next using
Gated Delta Rule), we maintain a **state pool** to store recurrent
states for all active requests:

`ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]`

where `pool_size` = `max_num_reqs` (maximum concurrent requests).

Each active request has a `req_pool_idx` that maps it to a slot in this
pool. The mapping is **not contiguous** - requests come and go, so
indices can be scattered (e.g., a batch of 4 requests might have pool
indices `[3, 7, 12, 25]`).

### Motivation

The current GDN decode kernel expects state with shape `[B, H, K, V]`
where B equals batch size and there's a 1:1 mapping (batch index i →
state index i). To use it with SGLang's pool, we would need to:

1. **Gather** states from pool indices before kernel call
2. Run kernel on contiguous `[B, H, K, V]` state
3. **Scatter** updated states back to pool indices

This adds 2 extra memory copy operations per decode step.
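For reference, the gather/scatter workaround described in steps 1–3 might look like this in PyTorch (an illustrative sketch; `run_decode_kernel` is a hypothetical stand-in for the real kernel entry point, assumed to update its state argument in place):

```python
import torch

def decode_with_gather_scatter(q, k, v, beta, pool_state, pool_idx, run_decode_kernel):
    """Baseline workaround the pooled API makes unnecessary: gather pool
    slots into a contiguous [B, H, K, V] buffer, run the kernel, then
    scatter the updated states back -- two extra device copies per step.
    """
    idx = pool_idx.to(torch.long)
    contig_state = pool_state.index_select(0, idx)        # gather (copy 1)
    out = run_decode_kernel(q, k, v, beta, contig_state)  # in-place state update
    pool_state.index_copy_(0, idx, contig_state)          # scatter (copy 2)
    return out
```

Both copies land on the critical path of every decode step, which is the overhead the `state_indices` path removes.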

### Changes

This PR adds a `state_indices` parameter for **zero-copy pool access**:

```python
def gated_delta_rule_decode_pretranspose(
    q, k, v, beta,
    state,           # Can be [pool_size, H, K, V] instead of [B, H, K, V]
    state_indices,   # NEW: int32 tensor [B] mapping batch_idx -> pool_idx
    ...
)
```

When `state_indices` is provided:
- Kernel uses indirect addressing: `state[state_indices[batch_idx]]`
instead of `state[batch_idx]`
- Negative indices (padding slots for CUDA graph) skip computation and
write zeros to output
- Eliminates gather/scatter overhead + host-side `torch.where` for
padding (~37μs/call)
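The contract above can be written as a small host-side reference that emulates the kernel's semantics, not its implementation (`step` is a hypothetical per-request update, not part of the API):

```python
import torch

def pooled_decode_reference(q, pool_state, state_indices, step):
    """Emulate the pooled-indexing contract: batch row i reads and
    writes pool slot state_indices[i]; negative indices (CUDA-graph
    padding) produce zero output and leave the pool untouched.
    `step(q_i, s)` stands in for the per-request decode update and
    returns (out_i, new_s)."""
    out = torch.zeros_like(q)
    for i in range(q.shape[0]):
        slot = int(state_indices[i])
        if slot < 0:
            continue  # padding slot: output stays zero, state unchanged
        out[i], pool_state[slot] = step(q[i], pool_state[slot])
    return out
```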

## 🔍 Related Issues

-
[sgl-project/sglang#18361](sgl-project/sglang#18361)
- FlashInfer K-last GDN integration into SGLang

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

This PR is required for integrating FlashInfer's K-last GDN kernels into
SGLang. The pool indexing feature allows SGLang to directly use its
state pool without gather/scatter overhead.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Optional pooled (indirect) state access via a new state_indices
parameter enabling zero-copy pooled state handling; negative indices
yield zeroed outputs.

* **Improvements**
* Kernel launch paths, grid sizing, and compiled-kernel caching
differentiate pooled vs. non-pooled modes; APIs and docstrings updated
to propagate pooling flags while maintaining compatibility.

* **Tests**
* New test suite validating pooled decode correctness,
padding/negative-index behavior, state updates, and pooled vs.
non-pooled equivalence.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
# Build h0_source: [pool_size*HV, V, K] for kernel
if use_pool:
    pool_size = initial_state.shape[0]
    assert initial_state.is_contiguous(), (
Contributor

Do we consider supporting non-contiguous state?

Contributor Author

In which situation do we need non-contiguous state?

Contributor

vLLM uses non-contiguous state

Contributor Author

vLLM uses non-contiguous state

For non-contiguous states, we should be able to compute the true indices using strides. Once the assert is removed, it can also work with our kernel.
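What that reply suggests can be demonstrated host-side: address a pool slot through the tensor's strides instead of assuming a contiguous `[pool_size, H, K, V]` layout (illustration only; the kernel would perform the same offset arithmetic on device):

```python
import torch

def read_slot(state, slot):
    """View pool slot `slot` of a possibly non-contiguous state tensor
    by computing its storage offset from stride(0), rather than relying
    on a contiguous [pool_size, H, K, V] layout."""
    return torch.as_strided(
        state,
        size=state.shape[1:],
        stride=state.stride()[1:],
        storage_offset=state.storage_offset() + slot * state.stride(0),
    )
```

For a transposed (hence non-contiguous) pool, `read_slot(pool, i)` matches `pool[i]` without any copy, which is why removing the contiguity assert suffices once the kernel indexes via strides.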

frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
kahyunnam pushed a commit that referenced this pull request Mar 25, 2026

## 📌 Description

vLLM uses non-contiguous state for GDN; make FlashInfer support it as well.
## 🔍 Related Issues
#2521
#2687


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes


Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>